K-means Segmentation

15. K-means Segmentation

!pip install moviepy scikit-image scikit-learn
import os
import cv2
import numpy as np
import requests
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import plotly.express as px
from moviepy.editor import VideoFileClip, ImageSequenceClip
from sklearn.cluster import KMeans
from collections import Counter
from skimage.segmentation import slic
from skimage.util import img_as_float
from IPython.display import HTML, display
from base64 import b64encode
import ipywidgets as widgets
import pandas as pd
WARNING:py.warnings:/usr/local/lib/python3.10/dist-packages/moviepy/video/io/sliders.py:61: SyntaxWarning: "is" with a literal. Did you mean "=="?
  if event.key is 'enter':
def download_video(video_url, save_path, timeout=60, chunk_size=64 * 1024):
    """Download a file over HTTP(S) and stream it to ``save_path``.

    Args:
        video_url: URL of the video to fetch.
        save_path: Local path the response body is written to.
        timeout: Seconds ``requests`` waits to connect / between received
            bytes. Without a timeout a stalled server hangs the notebook.
        chunk_size: Number of bytes read and written per chunk.

    Returns:
        True if the file was downloaded and saved, False if the server
        answered with a non-200 status code.
    """
    # Using the response as a context manager releases the connection even
    # when the body is not consumed (e.g. on an error status).
    with requests.get(video_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download video. Status code: {response.status_code}")
            return False
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
    print(f"Video downloaded successfully and saved to: {save_path}")
    return True
# Function to extract frames from a local video file
def extract_frames_from_video(video_path, output_dir, frame_rate=10, width=1024, height=1024):
    """Sample frames from a video, resize them, and save them as PNG files.

    Frames are written as ``frame_0000.png``, ``frame_0001.png``, ... into
    ``output_dir``, which is created if it does not exist.

    Args:
        video_path: Path to a video file readable by moviepy.
        output_dir: Directory the PNG frames are written to.
        frame_rate: Frames per second sampled from the clip.
        width, height: Output size in pixels passed to ``cv2.resize``
            (the aspect ratio is not preserved).
    """
    os.makedirs(output_dir, exist_ok=True)

    clip = VideoFileClip(video_path)
    try:
        for i, frame in enumerate(clip.iter_frames(fps=frame_rate)):
            resized_frame = cv2.resize(frame, (width, height))
            # moviepy yields RGB frames but cv2.imwrite expects BGR. Without this
            # conversion the PNGs have red/blue swapped, and every later
            # cv2.imread + BGR2HSV step operates on the wrong channels.
            bgr_frame = cv2.cvtColor(resized_frame, cv2.COLOR_RGB2BGR)
            frame_path = os.path.join(output_dir, f'frame_{i:04d}.png')
            cv2.imwrite(frame_path, bgr_frame)
    finally:
        # Release the video reader held by the clip.
        clip.close()

    print(f"Frames extracted and saved to: {output_dir}")
# Function to apply KMeans clustering to an image
def apply_kmeans(image, n_clusters, kmeans_model=None):
    """Replace every pixel of ``image`` with its nearest KMeans cluster centre.

    Args:
        image: Array of shape ``(H, W, C)``; each pixel's C channel values form
            one sample.
        n_clusters: Number of clusters to fit. Only used when ``kmeans_model``
            is None.
        kmeans_model: An already-fitted model exposing ``predict`` and
            ``cluster_centers_``. When given it is reused without refitting,
            which keeps cluster ids consistent across frames.

    Returns:
        ``(segmented_image, labels, kmeans)``:
        ``segmented_image`` has the same shape as ``image``, with each pixel
        set to its cluster centre, rounded and clipped to uint8.
        ``labels`` is the flat per-pixel cluster-id array.
        ``kmeans`` is the model that was used.
    """
    n_channels = image.shape[-1]
    pixel_values = np.float32(image.reshape((-1, n_channels)))

    if kmeans_model is None:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        kmeans.fit(pixel_values)
    else:
        kmeans = kmeans_model

    labels = kmeans.predict(pixel_values)
    centers = kmeans.cluster_centers_[labels]
    # Round (a bare uint8 cast truncates, biasing colours downward) and clip
    # so out-of-range centre values cannot wrap around.
    segmented_image = np.clip(np.rint(centers), 0, 255).astype(np.uint8)
    segmented_image = segmented_image.reshape(image.shape)

    return segmented_image, labels, kmeans
# Function to preprocess an image using CLAHE and HSV conversion
def preprocess_image(image, crop_rows=400, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Crop the bottom of a BGR image and equalize its brightness with CLAHE.

    The crop is converted to HSV, and CLAHE (contrast-limited adaptive
    histogram equalization) is applied to the V channel only. The result is
    converted back, so hue and saturation are left untouched.

    Args:
        image: BGR uint8 image, e.g. as returned by ``cv2.imread``.
        crop_rows: Number of rows kept from the bottom of the image.
        clip_limit, tile_grid_size: Forwarded to ``cv2.createCLAHE``.

    Returns:
        The cropped, contrast-equalized image, still in BGR channel order.

    Raises:
        ValueError: If ``image`` is None (``cv2.imread`` returns None for
            files it cannot read).
    """
    if image is None:
        raise ValueError("preprocess_image received None; was the image read successfully?")
    image_cropped = image[-crop_rows:, :, :]
    hsv_image = cv2.cvtColor(image_cropped, cv2.COLOR_BGR2HSV)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    hsv_image[:, :, 2] = clahe.apply(hsv_image[:, :, 2])
    return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Function to segment an image using SLIC algorithm
def segment_image(image, n_segments):
    """Return a SLIC superpixel label map for ``image``.

    The image is first converted to floating point with ``img_as_float`` and
    then passed to ``slic`` with compactness 10, with labels numbered from 0.
    """
    return slic(
        img_as_float(image),
        n_segments=n_segments,
        compactness=10,
        start_label=0,
    )
# Function to process an image, apply KMeans, save output, and add cluster information
def process_and_save_image(image, kmeans_model, filename, output_dir):
    """Cluster one frame with a fitted KMeans model and save the results.

    Writes two files into ``output_dir``:
      * ``filename``: the colour-quantized (clustered) image, and
      * ``<stem>_clusters.txt``: one line per cluster with its pixel count,
        percent cover and centre colour expressed in OpenCV HSV.

    Args:
        image: BGR uint8 frame (e.g. from ``cv2.imread``).
        kmeans_model: Fitted model whose ``cluster_centers_`` live in the same
            colour space as ``preprocess_image`` output (BGR).
        filename: Output image file name; also used to derive the txt name.
        output_dir: Directory both files are written to.

    Returns:
        The clustered BGR image with the cluster summary drawn on it. The
        saved image file does not contain the text overlay.
    """
    image_cropped = preprocess_image(image)
    # n_clusters is ignored here because a fitted model is passed in.
    segmented_image, labels, _ = apply_kmeans(image_cropped, n_clusters=4, kmeans_model=kmeans_model)

    centers = np.asarray(kmeans_model.cluster_centers_)
    label_counts = Counter(labels)
    total_pixels = image_cropped.shape[0] * image_cropped.shape[1]

    # preprocess_image returns BGR, so the model's centres are BGR. Convert
    # them to OpenCV HSV (H in 0-179) so the values written after "HSV:" are
    # really HSV.
    centers_bgr = np.clip(np.rint(centers), 0, 255).astype(np.uint8).reshape(1, -1, 3)
    centers_hsv = cv2.cvtColor(centers_bgr, cv2.COLOR_BGR2HSV)[0]

    cluster_info = []
    for i, (h, s, v) in enumerate(centers_hsv.astype(float)):
        count = label_counts[i]
        cluster_percentage = (count / total_pixels) * 100
        cluster_info.append(
            f"Cluster {i}: {count} pixels ({cluster_percentage:.2f}%) - HSV: ({h:.2f}, {s:.2f}, {v:.2f})"
        )

    # The quantized image is already in the input's BGR space, so it is saved
    # as-is. Converting it "from HSV" would scramble its colours.
    output_image_path = os.path.join(output_dir, filename)
    cv2.imwrite(output_image_path, segmented_image)

    output_txt_path = os.path.join(output_dir, filename.rsplit('.', 1)[0] + '_clusters.txt')
    with open(output_txt_path, 'w') as f:
        f.write("\n".join(cluster_info))

    # Draw the summary on a copy that is returned; the saved file stays clean.
    annotated = segmented_image.copy()
    for i, info in enumerate(cluster_info):
        text_position = (10, 30 + i * 20)
        cv2.putText(annotated, info, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    return annotated
# Function to process all frames and segment them
def process_frames(output_dir):
    """Cluster every frame image in ``output_dir`` with one shared KMeans model.

    The model is fitted on the first frame (in file-name order) and reused for
    all frames, so a cluster id means the same thing throughout the video.
    Each frame goes through ``process_and_save_image``, which overwrites the
    frame file with its clustered version and writes a ``_clusters.txt``
    summary next to it.

    Args:
        output_dir: Directory holding the extracted ``frame_XXXX`` images.

    Returns:
        The annotated clustered frames in file-name (i.e. temporal) order, or
        an empty list if there is nothing to process.
    """
    # Sort so frames come back in temporal order (os.listdir order is
    # arbitrary), and keep only images: the directory also accumulates the
    # *_clusters.txt files written by earlier runs.
    frames = sorted(
        f for f in os.listdir(output_dir)
        if f.endswith(('.png', '.jpg', '.jpeg'))
    )
    if not frames:
        print("No frames available for processing.")
        return []

    # Fit the shared model on the first frame only.
    first_image = cv2.imread(os.path.join(output_dir, frames[0]))
    if first_image is None:
        print(f"Could not read {frames[0]}; nothing processed.")
        return []
    _, _, kmeans_model = apply_kmeans(preprocess_image(first_image), n_clusters=4)

    processed_frames_list = []
    for filename in frames:
        image = cv2.imread(os.path.join(output_dir, filename))
        if image is None:
            print(f"Warning: could not read {filename}; skipping.")
            continue
        processed_frames_list.append(
            process_and_save_image(image, kmeans_model, filename, output_dir)
        )

    print("Segmentation completed, results saved.")
    return processed_frames_list
# Function to create a video from processed frames
def create_video_from_frames(frames, output_video_path, fps=10):
    """Encode a sequence of BGR frames as an H.264 video file.

    Args:
        frames: Sequence of BGR uint8 images, all the same size.
        output_video_path: Destination path of the video.
        fps: Frame rate of the output video.

    Raises:
        ValueError: If ``frames`` is None or empty.
    """
    # len() instead of truthiness so a NumPy stack of frames also works.
    if frames is None or len(frames) == 0:
        raise ValueError("create_video_from_frames needs at least one frame.")
    # moviepy expects RGB frames; OpenCV images are BGR.
    rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
    clip = ImageSequenceClip(rgb_frames, fps=fps)
    try:
        clip.write_videofile(output_video_path, codec='libx264')
    finally:
        clip.close()
    print(f"Video saved to: {output_video_path}")
video_url = "https://github.com/atticus-carter/cv/raw/refs/heads/main/videos/output_video_8.avi"
video_path = "/content/2022SHRSubset.avi"
frames_output_dir = "/content/frames"  # Change to your desired output directory
segmented_video_output_path = "/content/segmented_videos/segmented_video.mp4"

# Create the output directories (the video's parent is /content/segmented_videos).
os.makedirs(os.path.dirname(segmented_video_output_path), exist_ok=True)
os.makedirs(frames_output_dir, exist_ok=True)

# Download the video
download_video(video_url, video_path)

# Extract frames from local video
extract_frames_from_video(video_path, frames_output_dir)

# Process frames for segmentation
processed_frames = process_frames(frames_output_dir)

# Create video from processed frames; skip encoding when nothing was
# processed (process_frames may return None or an empty list).
if processed_frames:
    create_video_from_frames(processed_frames, segmented_video_output_path)
else:
    print("No processed frames; skipping video creation.")
def display_cluster_colors_with_image(cluster_file_path, image_path, video_path):
    """Show the original frame, the cluster colour swatches and the clustered frame.

    The first video frame is cropped to its bottom 400 rows, saved to
    ``/content/cropped_image.png`` and shown with Plotly. The cluster centre
    colours are read from ``frame_0000_clusters.txt`` (the values after
    "HSV: ", treated as OpenCV HSV) and drawn as matplotlib swatches. The
    clustered ``frame_0000.png`` is shown with Plotly.

    Args:
        cluster_file_path: Directory containing ``frame_0000_clusters.txt``.
        image_path: Directory containing the clustered ``frame_0000.png``.
        video_path: Source video whose first frame is shown as the original.

    Raises:
        ValueError: If the cluster file contains no "HSV" lines.
        FileNotFoundError: If the clustered image cannot be read.
    """
    clip = VideoFileClip(video_path)
    try:
        first_frame = clip.get_frame(0)  # moviepy returns RGB
    finally:
        clip.close()

    # Keep the bottom 400 rows, mirroring the crop used for clustering.
    # NOTE(review): extracted frames are resized to 1024x1024 before that crop,
    # so this region matches the clustered image only if the video is 1024 px tall.
    first_frame_cropped = first_frame[-400:, :, :]

    # cv2.imwrite expects BGR, so convert from moviepy's RGB before saving.
    os.makedirs("/content", exist_ok=True)
    cv2.imwrite("/content/cropped_image.png", cv2.cvtColor(first_frame_cropped, cv2.COLOR_RGB2BGR))

    # The frame is already RGB, which is what Plotly expects.
    fig_original = px.imshow(first_frame_cropped)
    fig_original.update_layout(title="Original Image Clip")
    fig_original.show()

    cluster_file = os.path.join(cluster_file_path, "frame_0000_clusters.txt")
    with open(cluster_file, 'r') as f:
        lines = f.readlines()

    hsv_colors = []
    for line in lines:
        if 'HSV' in line:
            hsv_str = line.split('HSV: ')[1].strip('()\n').split(',')
            hsv_colors.append([float(x.strip()) for x in hsv_str])
    if not hsv_colors:
        raise ValueError(f"No cluster colours found in {cluster_file}")

    image_file = os.path.join(image_path, "frame_0000.png")
    image2 = cv2.imread(image_file)
    if image2 is None:
        raise FileNotFoundError(f"Could not read clustered image {image_file}")
    image_rgb_clust = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

    # squeeze=False keeps `ax` 2-D even for a single cluster, so ax[0, i] works.
    fig, ax = plt.subplots(1, len(hsv_colors), figsize=(5, 2), squeeze=False)
    for i, hsv_color in enumerate(hsv_colors):
        hsv_pixel = np.clip(np.rint([[hsv_color]]), 0, 255).astype(np.uint8)
        bgr_color = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2BGR)[0][0]
        rgb_color = bgr_color[::-1]

        rect = patches.Rectangle((0, 0), 1, 1, facecolor=tuple(rgb_color / 255.0))
        ax[0, i].add_patch(rect)
        ax[0, i].axis('off')
        ax[0, i].set_title(f'Cluster {i}')

    plt.tight_layout()

    fig_image = px.imshow(image_rgb_clust)
    fig_image.update_layout(title="Clustered Image")

    # Show both plots (clustered image, then colour swatches)
    fig_image.show()
    plt.show()


# The frames directory holds both the per-frame *_clusters.txt files and the
# clustered frame images, so it serves as both lookup paths.
cluster_file_path = frames_output_dir
image_path = frames_output_dir
display_cluster_colors_with_image(cluster_file_path, image_path, video_path)
# NOTE(review): this dict looks like notebook cell-tag metadata carried over
# by the export; as Python it is an unused expression.
{
    "tags": [
        "hide-input",
    ]
}
def show_video(video_path, width=600):
    """Return an HTML ``<video>`` element that embeds an MP4 file inline.

    The whole file is base64-encoded into a data URL, so the notebook output
    is self-contained (and about 4/3 the size of the video).

    Args:
        video_path: Path to an MP4 file.
        width: Player width in pixels.
    """
    with open(video_path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML("""
  <video width="{0}" controls>
        <source src="{1}" type="video/mp4">
  </video>
  """.format(width, data_url))

show_video(segmented_video_output_path)
# NOTE(review): this dict looks like notebook cell-tag metadata carried over
# by the export; as Python it is an unused expression.
{
    "tags": [
        "hide-input",
    ]
}

# New cell for data visualization
import re
import plotly.graph_objs as go
import pandas as pd
import os  # Make sure os module is imported

N_CLUSTERS = 4  # must match the number of clusters used when segmenting
_PERCENT_RE = re.compile(r"\((\d+(?:\.\d+)?)%\)")  # matches e.g. "(30.12%)"


def _read_cluster_percentages(txt_file_path, n_clusters=N_CLUSTERS):
    """Return the percent cover of the first ``n_clusters`` lines of a *_clusters.txt file, or None if malformed."""
    with open(txt_file_path, 'r') as file:
        lines = file.readlines()
    if len(lines) < n_clusters:
        return None
    percentages = []
    for line in lines[:n_clusters]:
        match = _PERCENT_RE.search(line)
        if match is None:
            return None
        percentages.append(float(match.group(1)))
    return percentages


# Gather cluster data from text files
frame_numbers = []
cluster_percent_cover = {f'Cluster {k}': [] for k in range(N_CLUSTERS)}

# Sorted so the DataFrame rows are in frame order.
frames_dir = frames_output_dir
image_files = sorted(f for f in os.listdir(frames_dir) if f.endswith(('.png', '.jpg', '.jpeg')))

for image_file in image_files:
    frame_number = int(image_file.split('_')[1].split('.')[0])  # "frame_0012.png" -> 12
    txt_file_path = os.path.join(frames_dir, f"frame_{frame_number:04d}_clusters.txt")
    if not os.path.exists(txt_file_path):
        print(f"Warning: could not find {txt_file_path}")
        continue

    percentages = _read_cluster_percentages(txt_file_path)
    if percentages is None:
        print(f"Warning: insufficient data in {txt_file_path}")
        continue

    # Append only after the whole file parsed, so every cluster list (and
    # frame_numbers) keeps the same length.
    for k, value in enumerate(percentages):
        cluster_percent_cover[f'Cluster {k}'].append(value)
    frame_numbers.append(frame_number)

# Create a DataFrame for easier plotting
df = pd.DataFrame(cluster_percent_cover, index=frame_numbers)
df.index.name = 'Frame Number'

# Stacked bar chart, one trace per cluster
fig = go.Figure()
for column in df.columns:
    fig.add_trace(go.Bar(x=df.index, y=df[column], name=column))

fig.update_layout(
    barmode='stack',
    title='Percent Cover of Each Cluster Over Time',
    xaxis_title='Frame Number',
    yaxis_title='Percent Cover',
    template='plotly_white',
    hovermode='x unified',
    legend_title='Clusters'
)

fig.show()
# NOTE(review): this dict looks like notebook cell-tag metadata carried over
# by the export; as Python it is an unused expression.
{
    "tags": [
        "hide-input",
    ]
}

# One text box per cluster, pre-filled with its current name. Giving several
# clusters the same name merges them once "Apply Changes" is pressed.
rename_widgets = []
for cluster_idx in range(4):
    rename_widgets.append(
        widgets.Text(value=f'Cluster {cluster_idx}', description=f'Cluster {cluster_idx}:')
    )

print("Enter new names for the clusters:")
display(*rename_widgets)

# The confirm button and the area its handler prints into.
button = widgets.Button(description="Apply Changes")
output = widgets.Output()
def on_button_click(b):
    """Rename/merge clusters from the text boxes and store the result in ``mergeddf``.

    Clusters given the same name are summed into one column. The resulting
    DataFrame has one row per frame, sorted by frame number, and one column
    per unique name, in the order the names first appear in the text boxes.
    It is stored in the global ``mergeddf`` for the following cells.

    Args:
        b: The clicked button (unused; required by ``Button.on_click``).
    """
    with output:
        output.clear_output()  # Clear previous output
        new_names = [w.value for w in rename_widgets]
        # dict.fromkeys keeps first-seen order; list(set(...)) made the column
        # order change from run to run.
        unique_names = list(dict.fromkeys(new_names))

        frames_dir = frames_output_dir
        if not os.path.exists(frames_dir):
            print(f"Error: Frames directory {frames_dir} does not exist.")
            return

        # Sorted so rows are in frame order: the regression cell draws lines
        # through consecutive rows.
        image_files = sorted(f for f in os.listdir(frames_dir) if f.endswith(('.png', '.jpg', '.jpeg')))

        frame_numbers = []
        rows = []
        for image_file in image_files:
            frame_number = int(image_file.split('_')[1].split('.')[0])  # "frame_0012.png" -> 12
            txt_file_path = os.path.join(frames_dir, f"frame_{frame_number:04d}_clusters.txt")
            if not os.path.exists(txt_file_path):
                continue

            with open(txt_file_path, 'r') as file:
                data = file.readlines()
            # One "... (NN.NN%) ..." line per cluster, in cluster-id order.
            try:
                cluster_data = [
                    float(data[j].split('(')[1].split('%')[0].strip())
                    for j in range(len(new_names))
                ]
            except (IndexError, ValueError):
                continue

            # Sum the percent cover of clusters that share a name.
            merged_row = dict.fromkeys(unique_names, 0.0)
            for name, value in zip(new_names, cluster_data):
                merged_row[name] += value
            rows.append(merged_row)
            frame_numbers.append(frame_number)

        # Store DataFrame for further use
        global mergeddf
        mergeddf = pd.DataFrame(rows, index=frame_numbers, columns=unique_names)
        mergeddf.index.name = 'Frame Number'
        print("Clusters have been renamed and merged. DataFrame saved as 'mergeddf'.")

# Attach click event to button
button.on_click(on_button_click)

# Display button and output
display(button, output)
import statsmodels.api as sm
import plotly.graph_objs as go  # also imported by an earlier cell; kept so this cell stands alone

try:
    mergeddf
except NameError:
    print("Error: 'mergeddf' is not defined. Please run the previous cell to generate it.")
else:
    # Sort by frame number so each regression line is drawn left-to-right and
    # its annotation sits at the last frame.
    plot_df = mergeddf.sort_index()
    can_fit = plot_df.index.nunique() >= 2
    if not can_fit:
        print("Need at least two distinct frames to fit regression lines; showing bars only.")

    # Design matrix shared by every column: intercept + frame number.
    X = sm.add_constant(plot_df.index.values.reshape(-1, 1), has_constant='add')

    fig = go.Figure()
    for column in plot_df.columns:
        fig.add_trace(go.Bar(
            x=plot_df.index,
            y=plot_df[column],
            name=column
        ))
        if not can_fit:
            continue

        # OLS fit of percent cover against frame number.
        # NOTE(review): bars are stacked, but each line fits the unstacked
        # column values, so lines do not follow the stacked bar tops.
        y = plot_df[column].values
        model = sm.OLS(y, X).fit()
        predictions = model.predict(X)

        intercept = model.params[0]
        slope = model.params[1]
        equation = f'y = {slope:.2f}x + {intercept:.2f}'
        r_squared = model.rsquared

        fig.add_trace(go.Scatter(
            x=plot_df.index,
            y=predictions,
            mode='lines',
            name=f'{column} Regression',
            line=dict(color='red')
        ))

        # Annotate the equation and R-squared at the end of the line.
        fig.add_annotation(
            x=plot_df.index[-1],
            y=predictions[-1],
            text=f'{equation}<br>R² = {r_squared:.2f}',
            showarrow=False,
            font=dict(size=12)
        )

    fig.update_layout(
        barmode='stack',
        title='Percent Cover of Each Cluster Over Time',
        xaxis_title='Frame Number',
        yaxis_title='Percent Cover',
        template='plotly_white',
        hovermode='x unified',
        legend_title='Clusters'
    )

    fig.show()
# NOTE(review): this dict looks like notebook cell-tag metadata carried over
# by the export; as Python it is an unused expression.
{
    "tags": [
        "hide-input",
    ]
}

import statsmodels.api as sm  # also imported by an earlier cell; kept so this cell stands alone

# Ensure 'mergeddf' is defined
try:
    mergeddf
except NameError:
    print("Error: 'mergeddf' is not defined. Please run the previous cell to generate it.")
else:
    summary_stats = mergeddf.describe()
    print("Summary Statistics:\n", summary_stats)

    print("\nRunning Linear Regression Model:\n")
    # Build the design matrix once (intercept + frame number) instead of
    # re-applying add_constant to an already-augmented matrix on every loop.
    X = sm.add_constant(mergeddf.index.values.reshape(-1, 1), has_constant='add')
    for column in mergeddf.columns:
        y = mergeddf[column].values  # Cluster percentage cover as response
        model = sm.OLS(y, X).fit()
        print(f"\nLinear Regression Summary for Cluster: {column}\n")
        print(model.summary())
Summary Statistics:
                bg         bio
count  331.000000  331.000000
mean    77.907311   22.092568
std      1.092282    1.092802
min     76.010000   19.190000
25%     76.910000   21.165000
50%     77.900000   22.100000
75%     78.835000   23.090000
max     80.810000   23.990000

Running Linear Regression Model:


Linear Regression Summary for Cluster: bg

                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.476
Model:                            OLS   Adj. R-squared:                  0.474
Method:                 Least Squares   F-statistic:                     298.8
Date:                Mon, 28 Oct 2024   Prob (F-statistic):           4.34e-48
Time:                        02:44:27   Log-Likelihood:                -391.44
No. Observations:                 331   AIC:                             786.9
Df Residuals:                     329   BIC:                             794.5
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         76.6080      0.087    881.996      0.000      76.437      76.779
x1             0.0079      0.000     17.286      0.000       0.007       0.009
==============================================================================
Omnibus:                        6.352   Durbin-Watson:                   2.007
Prob(Omnibus):                  0.042   Jarque-Bera (JB):                5.561
Skew:                           0.245   Prob(JB):                       0.0620
Kurtosis:                       2.595   Cond. No.                         380.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

Linear Regression Summary for Cluster: bio

                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.476
Model:                            OLS   Adj. R-squared:                  0.475
Method:                 Least Squares   F-statistic:                     299.4
Date:                Mon, 28 Oct 2024   Prob (F-statistic):           3.77e-48
Time:                        02:44:27   Log-Likelihood:                -391.46
No. Observations:                 331   AIC:                             786.9
Df Residuals:                     329   BIC:                             794.5
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         23.3931      0.087    269.315      0.000      23.222      23.564
x1            -0.0079      0.000    -17.302      0.000      -0.009      -0.007
==============================================================================
Omnibus:                        6.318   Durbin-Watson:                   2.007
Prob(Omnibus):                  0.042   Jarque-Bera (JB):                5.565
Skew:                          -0.246   Prob(JB):                       0.0619
Kurtosis:                       2.599   Cond. No.                         380.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.